Multi-objective learning and explanation for stroke risk assessment in Shanxi province

Stroke is the leading cause of death in China (Zhou et al. in The Lancet, 2019). A dataset from Shanxi Province is analyzed to predict the risk state of each patient among four states (low/medium/high/attack) and to estimate the transition probabilities between states via a SHAP DeepExplainer. To handle the issues related to an imbalanced sample set, the quadratic interactive deep model (QIDeep) is proposed, which flexibly selects and appends quadratic interactive features. The experimental results show that the QIDeep model with 3 interactive features achieves a state-of-the-art accuracy of 83.33% (95% CI (83.14%, 83.52%)). Blood pressure, physical inactivity, smoking, weight, and total cholesterol are the five most important features. To obtain a high recall in the attack state, stroke occurrence prediction is considered as an auxiliary objective in multi-objective learning. The prediction accuracy is improved, while the recall of the attack state is increased by 14.8% (from 71.49% to 82.06%) compared to QIDeep with the same features. The prediction model and analysis tool in this paper provide not only a prediction method but also an attribution explanation of the risk state and transition direction of each patient, a valuable tool for doctors to analyze and diagnose the disease.

Stroke is the leading cause of death in China, and worldwide it is the second leading cause of death in people older than 60 years and the fifth leading cause of death among those aged 15 to 59 years [1]. More than two-thirds of stroke deaths occur in developing countries, with almost one-third occurring in China [2]. By 2017, stroke had become one of the top three causes of death [3], and it accounted for 1.57 million deaths in 2018 [4]. The rising number of stroke patients has put immense pressure on the public health system.
Stroke is a preventable disease. A number of potential risk factors, such as age, systolic blood pressure, and smoking, have been identified [5], which provides useful information for the general public. However, better algorithms are needed to improve the accuracy of stroke risk prediction and increase the effectiveness of preventive measures [6]. Traditional statistical methods, such as the Framingham Stroke Risk Profile (FSRP) [7], the new FSRP [8], and QStroke [9], aim to predict either the 10-year or the 5-year risk of stroke [10]. However, the performance of these algorithms depends heavily on the preselected features.
Machine learning (ML) techniques comprise a set of powerful algorithms that are capable of modeling complex and hidden relationships between a multitude of clinical variables and the desired clinical outcome without stringent statistical assumptions. There is growing interest in the application of ML techniques to clinical problems [11][12][13][14]. Unlike the steady growth in the application of ML methods in other industries, the use of ML, and especially of deep learning, on electronic health records (EHRs) has appeared only recently. A high-performance DNN model [15] was developed to predict 3-year and 8-year stroke occurrence based on a large EHR dataset. It was also demonstrated in [16][17][18] that DNNs performed better than logistic regression (LR) and random forest (RF) in predicting the long-term outcome of stroke occurrence. Furthermore, combining real-time electromyography biosignal data [19] with a long short-term memory (LSTM) algorithm that balances the memory ratio between records (long-term) and real-time data (short-term) has shown some promise. However, most of these methods focus on two-state classification, i.e., stroke/non-stroke. Three risk states (low/medium/high) were predicted in [20], and stroke risk factors were ranked within the 8 + 2 group of main risk factors defined by the China National Stroke Prevention Project (CSPP), with hypertension, physical inactivity (lack of sports), and obesity as the top three in Shanxi.
In this paper, we use the 8 + 2 main risk factors to identify an additional risk state. The attack (most urgent) state is introduced to represent patients who have experienced at least one stroke, motivated by the many studies that have considered stroke occurrence. Furthermore, we provide predictions and attributions of the current risk state and the transition probabilities for each patient. For example, when a patient is in the high-risk state, abnormal glycosylated hemoglobin and other factors make it possible for the patient to move into the attack state. This prompts us to add the attack state to the three risk states considered in [20].
The main contributions of this paper are described as follows.
• A DNN model with four-class multilayer perceptrons (MLPs) is used as a baseline for identifying risk states against the state-of-the-art tree models in healthcare (in "Experiment setup"). The SHAP analysis tool is applied to obtain the determinants of each patient's current risk state and their transition probabilities.
The remainder of the paper is organized as follows. The "Methods" section describes the proposed stroke risk prediction model and optimization methods, and the results are presented in the "Results" section, where implementation issues of the proposed method are also discussed. Concluding remarks are provided in the "Conclusions" section.

Methods
In this section, two models and a model interpretation method are introduced for stroke risk prediction and intervention. The workflow is shown in Fig. 1. Figure 1a describes the process of QIDeep, which improves the performance of stroke risk prediction over the base DNN. The model interpretation method SHAP takes the data and the trained DNN and returns the feature importance. We select the top-k important features to construct the order-2 features that are fed to QIDeep (shown in Fig. 2) along with the original data. The multi-gate mixture-of-experts (MMOE) model, which further improves stroke risk prediction, is depicted in Fig. 1b. It takes QIDeep and the base DNN as experts that contribute to both the stroke risk prediction and the stroke occurrence prediction, which serves as the auxiliary objective.
We first briefly review the material and tools that are utilized. By considering order-2 interactions, we then propose QIDeep to improve the performance of stroke risk prediction. Finally, stroke occurrence prediction is considered as an auxiliary objective in the MMOE model to achieve a higher recall.
The dataset used in this study consists of survey data and laboratory data from Shanxi Province collected between 2017 and 2020. Adults who agreed to complete the CSPP survey and had resided in Shanxi Province for at least 6 months were invited to participate in the screening process. All participants were registered at the local government office. The study was conducted according to the Declaration of Helsinki. The original data included 27,583 residents from 2017 to 2020 and 2,000 hospitalized stroke patients in 2018. After cleaning, 34 features are employed, including sex, age, and other basic features; smoking, lack of exercise (physical inactivity), and other lifestyle features; and laboratory data such as blood pressure.
Table 1 lists and explains some of the main features. After cleaning, the numbers of samples in the four states (low:medium:high:attack) are 7221:5868:5475:1967. If the low-, medium-, and high-risk data are merged into the non-stroke category, the sample ratio of stroke patients to non-stroke patients is nearly 1:10.
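The class imbalance quoted above can be checked with a few lines of arithmetic. The following is a minimal sketch that uses only the four state counts given in the text; the variable names are illustrative.

```python
# State counts after cleaning, as reported in the text.
counts = {"low": 7221, "medium": 5868, "high": 5475, "attack": 1967}

# Merge the three risk states into a single non-stroke class.
non_stroke = counts["low"] + counts["medium"] + counts["high"]
stroke = counts["attack"]

# Number of non-stroke samples per stroke sample.
ratio = non_stroke / stroke
print(f"stroke : non-stroke = 1 : {ratio:.2f}")
```

The ratio comes out a little below 10, which matches the "nearly 1:10" description and is the imbalance that the later stroke occurrence objective has to cope with.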

SHAP.
Feature importance refers to techniques that assign a score to input features based on their usefulness in predicting a target variable. Traditional feature importance methods can only rank the importance of different features, without explaining how each feature affects the predictions (positively or negatively). Proposed by Shapley in 1953 [21], the Shapley method is the only attribution method that satisfies the four properties of efficiency, symmetry, dummy, and additivity, which together can be regarded as the definition of a fair payout. SHAP [22] is a model interpretation package developed in Python that can interpret the output of most machine learning models. Inspired by the Shapley value, SHAP constructs an additive interpretation model, which interprets the predicted value of the model as the sum of the attribution values of the input features. All features are regarded as "contributors", and the SHAP value of a feature is its average marginal contribution to the output of the model.
Denote the i-th sample as $x_i$ and the j-th feature of the i-th sample as $x_{ij}$. Let $y_i$ be the predicted value of the model for this sample, and let $y_{\text{base}}$ represent the baseline output of the model. The following equation then holds:
$$y_i = y_{\text{base}} + \sum_{j} f(x_{ij}),$$
where the sum runs over all features of the sample and, for instance, $f(x_{i1})$ is the contribution of the first feature in the i-th sample to the final predicted value $y_i$. A value $f(x_{i1}) > 0$ indicates that the feature increases the predicted value, i.e., it has a positive effect. On the contrary, a negative value of $f(x_{i1})$ indicates that the feature has a negative effect.
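To make the additive relation concrete, the following minimal sketch computes exact Shapley values for a toy three-feature model by averaging marginal contributions over all feature orderings. It is not the SHAP DeepExplainer, which relies on approximations for deep networks. The model `toy_model`, the all-zero baseline, and the convention that an "absent" feature takes its baseline value are assumptions made for illustration only.

```python
from itertools import permutations


def shapley_values(model, x, baseline):
    """Exact Shapley values for one sample.

    Features are switched from their baseline value to their actual value
    one at a time, and the resulting changes in the model output are
    averaged over all feature orderings.
    """
    m = len(x)
    phi = [0.0] * m
    orderings = list(permutations(range(m)))
    for order in orderings:
        current = list(baseline)
        previous = model(current)
        for j in order:
            current[j] = x[j]
            new = model(current)
            phi[j] += new - previous
            previous = new
    return [p / len(orderings) for p in phi]


def toy_model(v):
    # Hypothetical risk score with an interaction between features 0 and 1.
    return 0.5 * v[0] + 0.2 * v[1] + 0.3 * v[0] * v[1] - 0.4 * v[2]


x = [1.0, 1.0, 1.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
y_base = toy_model(baseline)
y_i = toy_model(x)

print("contributions:", [round(p, 3) for p in phi])
print("y_base + sum of contributions:", round(y_base + sum(phi), 3))
print("model output y_i:", round(y_i, 3))
```

In this toy case the interaction term is split evenly between features 0 and 1, the third feature receives a negative contribution, and the baseline plus all contributions reproduces the model output.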
QIDeep. The main goal is to improve the performance of stroke risk prediction with a small training dataset. Thus, popular methods that combine all possible pairs of features, such as DeepFM, bring a heavier computational burden and may not work well for stroke risk assessment. The convolution methods are well known for their capacity to extract local correlations and are popular in computer vision; recently, studies of imple- ... [24], while direct usage of DeepFM may fail to converge due to the large number of parameters relative to our small dataset. As shown in Fig. 2, similar to factorization machines (FM) [25], the output of the QI component is a number of inner product units:
$$y_{\text{QI}} = \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle V_i, V_j \rangle \, x_i x_j, \qquad \langle V_i, V_j \rangle = \sum_{l=1}^{k} v_{i,l} \, v_{j,l},$$
where n is the number of combination features in the QI layer, k is the length of the latent vector, $V_i$ is the latent vector representation of feature $x_i$, and $v_{i,l}$ is the value of feature i at the l-th position of the latent vector.
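As a rough illustration of the QI component, the sketch below assumes that it follows the usual factorization-machine form, i.e., a sum of inner products of latent vectors weighted by the products of the corresponding feature values. The feature values `x` and latent vectors `V` are made-up numbers, and `qi_output` is a hypothetical helper, not a function from the authors' code.

```python
def inner(u, v):
    """Inner product of two latent vectors of equal length k."""
    return sum(a * b for a, b in zip(u, v))


def qi_output(x, V):
    """Sum over all pairs i < j of <V_i, V_j> * x_i * x_j."""
    n = len(x)
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            total += inner(V[i], V[j]) * x[i] * x[j]
    return total


# n = 3 selected features, latent vectors of length k = 2 (illustrative values).
x = [1.0, 2.0, 0.5]
V = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.4]]

print(round(qi_output(x, V), 6))
```

With n selected features there are n(n-1)/2 inner product units, so keeping n small keeps the number of extra parameters small, which is the property the text relies on for a small dataset.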
The main advantage of QIDeep is that, by adding order-2 interactions to the multiclass classification model, it can flexibly control the number of combined features, balancing the number of parameters against the sample size to ensure model convergence with a small dataset.
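The feature-selection step can be sketched as follows: rank the features by an importance score, keep the top k, and append one order-2 feature per pair of selected features. This is a simplified sketch that assumes the interaction features are plain pairwise products of the raw values; the importance scores, the sample row, and the helper names are hypothetical.

```python
from itertools import combinations


def top_k_features(importance, k):
    """Indices of the k features with the largest importance scores."""
    order = sorted(range(len(importance)), key=lambda j: importance[j], reverse=True)
    return order[:k]


def append_order2(row, selected):
    """Append the product of every pair of selected features to the row."""
    return list(row) + [row[i] * row[j] for i, j in combinations(selected, 2)]


# Hypothetical mean absolute SHAP value per feature (5 features for brevity).
importance = [0.05, 0.30, 0.10, 0.25, 0.20]
row = [1.0, 2.0, 3.0, 4.0, 5.0]

selected = top_k_features(importance, 3)
augmented = append_order2(row, selected)

print("selected feature indices:", selected)
print("augmented row:", augmented)
```

Under this assumption, the top three features yield three interaction features, so a 34-feature input would grow to a 34 + 3 structure, while the top four would yield six. The number of added inputs is therefore controlled directly by k.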

MMOE.
The recall of the attack state is especially important for clinical diagnosis, since ideally every urgent case should be predicted correctly. The MMOE model [26], one of the most popular methods of multi-objective learning, is selected to further improve the prediction of the attack state without noticeably increasing the number of model parameters.
Stein's paradox [27] in statistics states that estimating the means of three or more Gaussian random variables using samples from all of them can yield better performance than estimating them separately. Multi-objective learning refers to learning and optimizing multiple tasks simultaneously by exploiting both the common information and the task-specific information among tasks.
As shown in Fig. 1b, we apply the MMOE model to simultaneously optimize the predictions of both stroke occurrence and stroke risk. The intermediate layers of the DNN and QIDeep are embedded as Expert1 and Expert2, respectively. The expert outputs are combined with the weights produced by the gates, and the resulting weighted sum is fed to the towers, which learn the characteristics of each objective. Tower1 returns the stroke occurrence prediction, and Tower2 outputs the stroke risk prediction. We note that stroke occurrence prediction is considered as an auxiliary objective to draw more attention to the attack state [28].
With MMOE, both objectives can be learned through common information and specific information. Concretely, in Fig. 1b, for objective $k \in \{1, 2\}$, $h^k$ stands for the mapping to predict objective $k$, and its input
$$f^k(x) = \sum_{i=1}^{n} g_i^k(x) \, f_i(x)$$
is the weighted sum of the experts under the corresponding gate. This means that the gating network of the k-th objective realizes a selective utilization of the experts through different weights. The $n$ expert networks are denoted as $\{f_i\}_{i=1}^{n}$, and $g_i^k(x)$ is the weight of the i-th expert on the final decision for objective $k$, which satisfies $\sum_{i=1}^{n} g_i^k(x) = 1$. In the learning of different objectives, the effects of the different experts' opinions are adjusted through the gates.
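The gating step can be sketched in a few lines. The example below assumes softmax gates, as in the original MMOE formulation, and uses made-up expert outputs and gate scores for a single sample; in the real model both would be produced by trained networks.

```python
import math


def softmax(scores):
    """Numerically stable softmax; the returned weights sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mix_experts(expert_outputs, gate_scores):
    """Weighted sum of expert output vectors under one objective's gate."""
    weights = softmax(gate_scores)
    dim = len(expert_outputs[0])
    mixed = [
        sum(w * out[d] for w, out in zip(weights, expert_outputs))
        for d in range(dim)
    ]
    return mixed, weights


# Hypothetical outputs of two experts (e.g., DNN and QIDeep hidden layers).
expert_outputs = [[1.0, 0.0], [0.0, 2.0]]
# Hypothetical gate scores, one list per objective.
gate_scores = {"occurrence": [2.0, 0.0], "risk": [0.0, 1.0]}

mixed = {name: mix_experts(expert_outputs, s) for name, s in gate_scores.items()}
for name, (vector, weights) in mixed.items():
    print(name, [round(w, 3) for w in weights], [round(v, 3) for v in vector])
```

Each objective sees the same experts but with its own weights, so the occurrence tower and the risk tower receive different mixtures of the shared representations.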
The loss function is expressed as follows:
$$\mathcal{L} = l_1\big(h^1(f^1(x)),\, y\big) + l_2\big(h^2(f^2(x)),\, z\big),$$
where $h^1$ and $h^2$ are the mappings that produce the stroke occurrence prediction and the stroke risk prediction, respectively, from the gated expert mixtures $f^1(x)$ and $f^2(x)$; $l_1$ is the binary cross-entropy function and $l_2$ is the cross-entropy function; and $y$ and $z$ are the labels of the stroke occurrence prediction and the stroke risk prediction, respectively.
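A minimal sketch of such a two-objective loss is given below. It assumes that the two terms are simply added without task weights, which the text does not specify; the function names and the example probabilities are illustrative.

```python
import math


def binary_cross_entropy(p, y):
    """l1: p is the predicted probability of stroke occurrence, y is 0 or 1."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def cross_entropy(probs, z):
    """l2: probs are the four state probabilities, z is the index of the true state."""
    return -math.log(probs[z])


def two_objective_loss(p_occurrence, y, probs_risk, z):
    # Assumption: unweighted sum of the two objectives.
    return binary_cross_entropy(p_occurrence, y) + cross_entropy(probs_risk, z)


# A stroke patient (y = 1) whose true state is "attack" (index 3).
good = two_objective_loss(0.9, 1, [0.05, 0.05, 0.20, 0.70], 3)
bad = two_objective_loss(0.3, 1, [0.05, 0.05, 0.70, 0.20], 3)

print(round(good, 4), round(bad, 4))
```

Because an attack sample is penalized by both terms when it is missed, the auxiliary objective adds gradient signal for exactly the state whose recall matters most.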

Results
In this section, we compare the performance of the proposed QIDeep model with that of several state-of-the-art models and show the further improvement of the attack state obtained by MMOE. In the numerical experiments, we apply Adam with a learning rate of 0.01 to optimize the models, and early stopping with a patience of 20 is used to prevent overfitting.
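The early stopping rule can be sketched independently of any deep learning framework. The class below is a hypothetical helper that assumes the monitored quantity is a validation loss measured once per epoch, which the text does not state explicitly. In PyTorch, the optimizer itself would typically be created with `torch.optim.Adam(model.parameters(), lr=0.01)`.

```python
class EarlyStopping:
    """Stop when the monitored loss has not improved for `patience` epochs in a row."""

    def __init__(self, patience=20):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Hypothetical validation losses: five improving epochs, then a plateau.
losses = [1.0, 0.8, 0.7, 0.65, 0.6] + [0.61] * 30

stopper = EarlyStopping(patience=20)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print("stopped at epoch", stopped_at, "with best loss", stopper.best)
```

In practice the model weights from the best epoch would also be saved and restored after stopping; that bookkeeping is omitted here.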
Evaluation metrics. We use four evaluation metrics in our experiments:
• Accuracy: the proportion of correctly predicted samples among the total number of samples. Accuracy is the most intuitive indicator of the quality of the model.
• Precision: the proportion of samples predicted as a given state that truly belong to that state.
• Recall: the proportion of samples of a given state that are correctly predicted.
• F1 score: the harmonic mean of precision and recall.
Figure 3a is the feature
We use the high-risk state explainer as an example. Figure 3b,c describe the interactions of Exs-TG and Sm-LDBP, respectively. The normal ranges of the indicators in Table 1 are also illustrated in these figures (green dotted lines). We note that when the indicators sorted along the x-axis are outside their normal range, lack of exercise (y = 1) pushes patients toward the high-risk state, as reflected by the greater SHAP values. Nonsmoking does not affect the high-risk prediction, while smoking increases the probability of the high-risk state, especially when blood pressure increases.
Numerical results. The results are shown in Table 3. For the QIDeep model, the number of features, shown as 34+N, means that the input is composed of the 34 original features and N order-2 interaction features. It can be seen from the table that as the number of order-2 interaction features increases, the mean recall decreases slightly, while the model converges more rapidly, as shown by the drop in the number of iterations. Specifically, we list the recall of the attack state for each feature set. QIDeep with three order-2 interaction features, whose expected recall is 78.81% (95% CI (78.13%, 79.50%)), appears to achieve the best overall performance.
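Since the results report both the overall accuracy and the recall of individual states, it may help to see how they are derived from a four-state confusion matrix. The matrix below is invented for illustration and does not reproduce any table in the paper; `per_class_metrics` is a hypothetical helper.

```python
def per_class_metrics(confusion):
    """confusion[t][p] = number of samples with true state t predicted as state p."""
    n = len(confusion)
    total = sum(sum(row) for row in confusion)
    accuracy = sum(confusion[c][c] for c in range(n)) / total
    metrics = []
    for c in range(n):
        tp = confusion[c][c]
        predicted = sum(confusion[t][c] for t in range(n))  # column sum
        actual = sum(confusion[c])                           # row sum
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics.append({"precision": precision, "recall": recall, "f1": f1})
    return accuracy, metrics


# Rows and columns: low, medium, high, attack (hypothetical counts).
confusion = [
    [90, 8, 2, 0],
    [10, 70, 20, 0],
    [2, 15, 75, 8],
    [0, 0, 6, 14],
]
accuracy, metrics = per_class_metrics(confusion)

print("accuracy:", round(accuracy, 4))
print("attack recall:", round(metrics[3]["recall"], 4))
```

In this made-up case the overall accuracy is higher than the recall of the small attack class, which is why the attack recall is reported separately rather than inferred from accuracy.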
To increase the interpretability and practicability of the model, SHAP DeepExplainer for the QIDeep model is established to identify the dominant risk factors leading to state transition for each subject. Here we use two examples to explain how SHAP DeepExplainer works.
In the first example, we randomly select a sample from the high-risk individuals. The predicted value for the four states is (0, 0, 3.6051, 0), which means that the predicted state is high-risk, with a maximum score of 3.6051, and that there is no state transition potential. (Although according to (2) the predicted values of the four states should sum to 1, in the numerical experiment we use the values before softmax, because the cross-entropy loss of PyTorch is applied directly to these values to avoid a vanishing gradient.) Through the single-sample analysis tool of SHAP, we obtain four explainers that correspond to the four states. These four explainers take a single sample as input and provide the force plot of each state. Figure 4a shows the force plot obtained by the explainer corresponding to the high-risk state. In the figure, the Shapley values of the features are visualized as "forces", and each feature value is a force that increases or decreases the prediction. The prediction starts from the base value, which is the average of the predictions, and each Shapley value is presented as an arrow that increases (red) or decreases (blue) the prediction. Figure 4a shows that smoking behavior (even worse than 20 years of smoking), lack of exercise, a higher BMI, and abnormal indicators, such as blood pressure, push this patient toward the high-risk state, while the absence of a family history of stroke pulls the patient away from the high-risk state.
Figure 4b,c shows another example of an analysis of a high-risk patient. The predicted value for the four states is (0, 0, 0.40, 0.21), which means that the patient tends to move from the high-risk state to the attack state. Combined with Fig. 4b,c, it can be seen that the patient is in the high-risk state due to a family history of stroke and hypertension, current high blood pressure, and older age (Ret = 1). However, the patient does not smoke and keeps exercising to maintain a normal BMI and TC, which reduces the risk of stroke.
In particular, we can see that the patient has a higher level of education, which suggests that a better understanding of and greater attention to the disease would probably reduce the risk of a stroke attack. This finding is consistent with the conclusion of a previous study [2], i.e., that the prevention of stroke can be accomplished by better and earlier treatment of hypertension and by health education. Based on the analysis results, we suggest that this patient adopt a light diet (Flv) and take medication to maintain normal blood pressure and cholesterol levels.
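The pre-softmax scores quoted for the second example can be turned into probabilities that sum to 1 with a standard softmax. The sketch below uses the four values from the text; the state names and their order are taken from the text, and the helper is illustrative.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of pre-softmax scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


states = ["low", "medium", "high", "attack"]
scores = [0.0, 0.0, 0.40, 0.21]  # second example in the text

probs = softmax(scores)
ranked = sorted(zip(states, probs), key=lambda pair: pair[1], reverse=True)

for state, p in ranked:
    print(f"{state}: {p:.3f}")
```

Softmax preserves the ordering of the scores, so the high-risk state remains the prediction and the attack state remains the runner-up, which is the basis for reading the second-largest score as the likely transition direction.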
MMOE: further improvement of the attack state.
Model selection for the auxiliary objective. Table 4 shows the prediction results of currently popular algorithms for stroke occurrence prediction on our dataset. Apart from the abovementioned indicators, the AUC, the area under the ROC curve, is considered to evaluate ... [32]. The AUC indicates whether the ranking is correct, i.e., whether the scores of positive samples are greater than those of negative samples, because it considers both sensitivity and specificity [33] and does not depend on the selection of thresholds. As shown in Table 4, because DNN-B (B denotes binary, to distinguish DNN-B from the base DNN model for four-class classification) is the best model on all indicators, it is chosen as Expert1 of MMOE. Compared with LR, gradient harmonizing mechanism logistic regression (GHMLR) [34] performs better on the imbalanced sample set of stroke occurrence (1:10, as mentioned above) by using the gradient density as a hedge against the disharmony between different examples. RF is better than GHMLR, as its precision, recall, and F1 score are better with similar AUC values. Therefore, the AUC cannot be treated as the sole evaluation indicator in our project.
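The ranking interpretation of the AUC, and the reason it cannot be the only indicator, can be illustrated with a small made-up example. The scores and labels below are hypothetical, and the 0.5 decision threshold is an assumption chosen for illustration.

```python
def pairwise_auc(scores, labels):
    """Fraction of (positive, negative) pairs in which the positive sample
    receives the higher score; ties count as half."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in pos:
        for q in neg:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Two stroke samples among ten (hypothetical scores).
scores = [0.9, 0.4, 0.8, 0.3, 0.2, 0.1, 0.35, 0.05, 0.15, 0.25]
labels = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]

auc = pairwise_auc(scores, labels)

# Recall at a fixed decision threshold.
threshold = 0.5
tp = sum(1 for s, y in zip(scores, labels) if y == 1 and s >= threshold)
recall = tp / sum(labels)

print("AUC:", auc, "recall at 0.5:", recall)
```

Here the ranking is almost perfect, yet half of the positive samples fall below the threshold, so a high AUC can coexist with a modest recall.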
Numerical results. For the MMOE model, attack occurrence prediction is objective 1, and stroke risk assessment is objective 2. We emphasize that we select the top 20 features from the original 34 features as the input to ensure convergence. The MMOE model outperforms the single-objective models on both objectives. For attack occurrence prediction, the average recall is increased by 5.68% (from 84.75% to 89.56%), and the AUC is increased by 1.05% (from 97.10% to 98.12%), with similar precision (DNN: 94.94%, MMOE: 94.43%). For stroke risk prediction, the overall prediction accuracy is 84.51% (95% CI (84.17%, 84.84%)) with a feature structure of 20 + 3, which is 1.37 percentage points higher than that of the QIDeep model with a feature structure of 20 + 3 (83.14% (95% CI (82.22%, 84.05%))) and even 1.18 percentage points higher than that of the QIDeep model with a feature structure of 34 + 3 (83.33%). The QIDeep and MMOE entries of Table 2 show that the prediction performance (precision, recall, and F1 score) of each state is improved by MMOE. We believe that when the dataset is large enough, the two-objective optimization of stroke occurrence and stroke risk prediction based on all features within the MMOE framework would achieve better results.
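The improvements quoted above mix two conventions: the recall and AUC gains are relative changes, whereas the accuracy gains are differences in percentage points. The following short check, using only the before and after values from the text, makes the distinction explicit; the function and dictionary names are illustrative.

```python
def absolute_gain(before, after):
    """Difference in percentage points."""
    return after - before


def relative_gain(before, after):
    """Change relative to the starting value, in percent."""
    return (after - before) / before * 100.0


reported = {
    "occurrence recall": (84.75, 89.56),
    "occurrence AUC": (97.10, 98.12),
    "risk accuracy (20 + 3)": (83.14, 84.51),
}

for name, (before, after) in reported.items():
    print(
        f"{name}: +{absolute_gain(before, after):.2f} points, "
        f"+{relative_gain(before, after):.2f}% relative"
    )
```

The reported 5.68% and 1.05% match the relative gains, while the reported 1.37 matches the absolute difference in accuracy.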
On the premise of improving the overall effect of the model, we further compare the prediction results of each method for the attack state, as shown in Table 5. This comparison is also the motivation for using the multi-objective model. In this case, the base DNN model uses 20 original features, and QIDeep and MMOE use a feature structure of 20 + 3. It can be seen from Table 5
Managerial implication. Traditionally, laboratory data are collected by the doctor during diagnosis. The stroke risk state is estimated according to the doctor's personal experience along with the 8+2 rules from the CSPP. Generally, doctors cannot attribute the risk state to more detailed factors. This paper aims at personalized and intelligent diagnostics. Through accurate prediction models and their explainers, the proposed methods identify the main factors affecting the current risk state and the transition trends for each patient individually. The outputs of the proposed methods can help doctors to diagnose and can be presented visually to patients to give them a clear sense of the factors involved and the suggested actions, such as quitting smoking and exercising.

Conclusion
In view of the diagnosis and analysis of stroke, stroke risk assessment models based on deep learning methods have been proposed. Moreover, the proposed models can identify the determinants of the risk state of every subject, which makes personalized treatment possible and improves the effectiveness of stroke intervention and prevention.
Specifically, to improve the overall accuracy of stroke risk prediction, we propose the QIDeep model, which adds quadratic interactive features to a DNN to cope with a small dataset. This method can flexibly control the number of combination features and balance the model parameters against the sample size to ensure model convergence. In addition, to address the problem that this method is not effective enough for the important attack state, we draw on the multi-objective learning framework of MMOE and build a model with the objective of optimizing both stroke occurrence prediction and stroke risk prediction. With the same features, the prediction performance of each risk state is improved: the accuracy of the stroke risk assessment is improved by 1.37 percentage points (from 83.14% to 84.51%), the F1 score of the attack state is increased by 7.4% (from 78.99% to 84.87%), and the recall is increased by 14.8% (from 71.49% to 82.06%) compared with the single-objective QIDeep method. This method can be applied to the prediction of other diseases for which missing data are common and risk factors are not well understood.
Unfortunately, SHAP estimates the feature importance by linear approximation. The explainer can order the features by their attributions, but the attributions are not precisely quantified. Hence, its application to precision medicine is not straightforward. In the future, we will improve the performance of stroke risk assessment by collecting data from multiple dimensions. To best benefit from the census data, we will try a meta-learning framework [35]. More precisely, the meta-knowledge of intelligent diagnostics will be trained on the lifestyle data and then tuned to the specific disease. The goal is to achieve few-shot convergence and accurate prediction on the small laboratory dataset.

Data availability
The implementation of the numerical experiment is available at https://github.com/MadelineMa/MTL4StrokeAssessment. Besides the application to the stroke dataset, we also apply the models to the MNIST handwritten digit dataset (https://github.com/MadelineMa/MTL4StrokeAssessment/tree/main/src/MNIST), which offers evidence of the generalization of the proposed methods.